package com.example.demo.socket;

import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.json.JSONUtil;
import com.example.demo.model.Message;
import com.example.demo.rpc.Wav;
import com.example.demo.utils.PcmCovWavUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import org.vosk.DecoderUtil;
import org.vosk.Recognizer;

import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * WebSocket endpoint for audio communication.
 *
 * <p>Each binary frame is treated as a chunk of PCM audio. The chunk is appended to a
 * per-user recording (dumped as a WAV file under {@code tempAudio/}), converted to WAV in
 * memory, fed to the user's speech recognizer, and the partial transcript is forwarded
 * through {@code super.onMessage(String)}.
 *
 * @author yzd
 */
@Component
@ServerEndpoint(value = "/websocket/chat/audio", configurator = WsConfigurator.class)
public class AudioController extends BaseMediaController {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    private static final List<AbstractWsController> CONNECTIONS = new CopyOnWriteArrayList<>();

    /** Name of the directory (relative to the working directory) holding per-user recordings. */
    private static final String RECORDING_DIR = "tempAudio";

    /**
     * Executor that runs the per-chunk audio processing. It is held in a static field and
     * populated through the {@code @Autowired} setter below so that every instance of this
     * class shares the same pool.
     */
    private static ThreadPoolTaskExecutor threadPoolTaskExecutor;

    /**
     * Stores the injected executor in the shared static field.
     *
     * @param threadPoolTaskExecutor executor used for asynchronous audio processing
     */
    @Autowired
    public void setThreadPoolTaskExecutor(ThreadPoolTaskExecutor threadPoolTaskExecutor) {
        AudioController.threadPoolTaskExecutor = threadPoolTaskExecutor;
    }

    /** Speech recognizer per user name. */
    private static final ConcurrentHashMap<String, Recognizer> RECOGNIZERS = new ConcurrentHashMap<>(30);

    /** Accumulated raw audio per user name (the whole recording received so far). */
    private static final ConcurrentHashMap<String, ByteArrayOutputStream> AUDIO_BUFFERS = new ConcurrentHashMap<>(30);

    /**
     * Delegates to the parent handler, then registers a recognizer for the current user.
     *
     * @param session the WebSocket session being opened
     * @param config  the endpoint configuration
     */
    @Override
    @OnOpen
    public void onOpen(Session session, EndpointConfig config) {
        super.onOpen(session, config);
        RECOGNIZERS.put(this.getUserName(), DecoderUtil.recognizer());
    }

    /**
     * Drops the current user's recognizer and audio buffer, then delegates to the parent handler.
     */
    @Override
    @OnClose
    public void onClose() {
        releaseSessionResources();
        super.onClose();
    }

    /**
     * Handles a text message by delegating to the parent handler.
     *
     * @param message the text payload
     */
    @Override
    @OnMessage(maxMessageSize = 10000000)
    public void onMessage(String message) {
        super.onMessage(message);
    }

    /**
     * Handles a binary audio chunk. Null or empty buffers are ignored; otherwise the payload
     * is copied and processed asynchronously on the shared executor.
     *
     * @param message the binary payload containing audio data
     */
    @Override
    @OnMessage(maxMessageSize = 10000000)
    public void onMessage(ByteBuffer message) {
        if (message == null || !message.hasRemaining()) {
            return;
        }

        // Copy the payload before leaving this method: the buffer may be direct or read-only
        // (array() would throw), may expose bytes outside position/limit, and may be reused
        // by the container once this callback returns.
        final byte[] chunk = new byte[message.remaining()];
        message.duplicate().get(chunk);
        final String userName = this.getUserName();

        CompletableFuture.runAsync(() -> {
            try {
                processChunk(userName, chunk);
            } catch (Exception e) {
                // Without this, a failure would be silently captured by the future.
                logger.error("Failed to process audio chunk for user {}", userName, e);
            }
        }, threadPoolTaskExecutor);
    }

    /**
     * Processes one audio chunk: appends it to the user's recording, rewrites the recording
     * file, runs recognition on the chunk and forwards the partial transcript.
     *
     * @param userName user the chunk belongs to
     * @param chunk    copied audio bytes of the incoming frame
     */
    private void processChunk(String userName, byte[] chunk) {
        Recognizer recognizer = RECOGNIZERS.get(userName);
        if (recognizer == null) {
            // Session was already closed or errored; nothing left to transcribe into.
            logger.debug("No recognizer registered for user {}, dropping audio chunk", userName);
            return;
        }

        // Whole-recording storage.
        appendBuffer(userName, chunk);

        // In-memory PCM -> WAV conversion of just this chunk.
        ByteArrayOutputStream pcmChunk = new ByteArrayOutputStream(chunk.length);
        pcmChunk.write(chunk, 0, chunk.length);
        ByteArrayOutputStream wavChunk = new ByteArrayOutputStream();
        PcmCovWavUtil.convertWaveFile(pcmChunk, wavChunk);
        byte[] wavBytes = wavChunk.toByteArray();

        // NOTE(review): the parsed Wav object is never read afterwards; kept in case
        // GetFromBytes is relied on for validation - confirm and remove if not.
        Wav waveFile = new Wav();
        waveFile.GetFromBytes(wavBytes);

        writeRecording(userName);

        // The executor may run several chunks of the same user concurrently, so access to
        // the recognizer is serialized per instance.
        // NOTE(review): chunks may still be processed out of arrival order.
        String partial;
        synchronized (recognizer) {
            recognizer.acceptWaveForm(wavBytes, wavBytes.length);
            partial = JSONUtil.parseObj(recognizer.getPartialResult()).getStr("partial");
        }
        if (partial == null) {
            partial = "";
        }
        // Literal replacement: "[FIL]" as a regex would be a character class matching F, I or L.
        partial = partial.replace("[FIL]", "，");

        if (partial.length() > 30 && partial.length() % 30 < 3) {
            partial = partial + "\n";
        }
        logger.info("Partial transcript for {}: {}", userName, partial);

        // Forward the transcription result.
        Message msg = new Message(userName, Message.MsgConstant.MSG_TO_ALL, partial);
        super.onMessage(msg.toString());
    }

    /**
     * Writes the user's accumulated recording as {@code tempAudio/<userName>.wav} under the
     * working directory. Failures are logged and do not interrupt transcription.
     *
     * @param userName user whose recording is written
     */
    private void writeRecording(String userName) {
        ByteArrayOutputStream recorded = AUDIO_BUFFERS.get(userName);
        if (recorded == null) {
            return;
        }
        try {
            Path directory = Paths.get("").toAbsolutePath().resolve(RECORDING_DIR).normalize();
            Path target = directory.resolve(userName + ".wav").normalize();
            if (!target.startsWith(directory)) {
                // The user name is part of the file name; refuse anything escaping the directory.
                logger.warn("Refusing to write recording outside {}: {}", directory, target);
                return;
            }
            File directoryFile = directory.toFile();
            if (!directoryFile.mkdirs() && !directoryFile.isDirectory()) {
                logger.warn("Could not create recording directory {}", directory);
                return;
            }

            // Hold the buffer's monitor so the snapshot is consistent and writers for the
            // same user do not interleave on the same file.
            synchronized (recorded) {
                ByteArrayOutputStream wav = new ByteArrayOutputStream();
                PcmCovWavUtil.convertWaveFile(recorded, wav);
                try (FileOutputStream out = new FileOutputStream(target.toFile())) {
                    wav.writeTo(out);
                    out.flush();
                }
            }
        } catch (Exception e) {
            logger.error("Failed to write recording for user {}", userName, e);
        }
    }

    /**
     * Appends a chunk to the user's accumulated audio buffer, creating the buffer on first use.
     *
     * @param userName user the chunk belongs to
     * @param chunk    audio bytes to append
     */
    private void appendBuffer(String userName, byte[] chunk) {
        ByteArrayOutputStream buffer =
                AUDIO_BUFFERS.computeIfAbsent(userName, key -> new ByteArrayOutputStream());
        buffer.write(chunk, 0, chunk.length);
    }

    /**
     * Removes the current user's recognizer and audio buffer from the shared maps. Any
     * exception is logged so the caller can continue its own shutdown handling.
     */
    private void releaseSessionResources() {
        try {
            String userName = this.getUserName();
            // NOTE(review): the removed Recognizer is not closed here because it is unclear
            // whether DecoderUtil.recognizer() returns a dedicated instance - confirm, and
            // close it if so to release its resources.
            RECOGNIZERS.remove(userName);
            AUDIO_BUFFERS.remove(userName);
        } catch (Exception e) {
            logger.warn("Failed to release audio resources", e);
        }
    }

    /**
     * Logs the error and drops the current user's recognizer and audio buffer.
     *
     * @param t the error reported for this connection
     */
    @Override
    @OnError
    public void onError(Throwable t) {
        logger.error("WebSocket audio connection error", t);
        releaseSessionResources();
    }

    /**
     * @return the shared list of connections for this endpoint type
     */
    @Override
    List<AbstractWsController> getConnections() {
        return CONNECTIONS;
    }

    /**
     * @return the connection type identifier, {@code "audio"}
     */
    @Override
    String getConnectType() {
        return "audio";
    }

}
